{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "3ede0360-00f4-404e-b0d2-4a83cc385654",
   "metadata": {
    "id": "3ede0360-00f4-404e-b0d2-4a83cc385654"
   },
   "source": [
    "🔗 Ensemble Model\n",
    "---\n",
    "We’ll reuse core components built earlier:\n",
    "\n",
    "- A fine-tuned LLaMA model\n",
    "- An XGBoost regression model, stored in Hugging Face\n",
    "- A ChromaDB vector store, stored on Google Drive and also available on AWS S3\n",
    "- A GPT-4o mini + RAG pipeline\n",
    "\n",
    "We'll run all three models on the same test data, gather their predictions, and train a Linear Regression Ensemble. The ensemble learns how to combine these predictions to output a more accurate final price.\n",
    "\n",
    "Once trained, we'll save the ensemble as ensemble_model.pkl, ready for later use.\n",
    "\n",
    "- 🧑‍💻 Skill Level: Advanced\n",
    "- ⚙️ Hardware: ⚠️ GPU required (use Google Colab)\n",
    "- 🛠️ Requirements: \n",
    "\n",
    "    - 🔑 Hugging Face Token and OpenAI Key — must be set in Google Colab secrets or .env files if you are running with your own GPU\n",
    "    - completion of Part 9 of [this series of notebooks](https://github.com/lisekarimi/lexo)\n",
    "- 🎯 Task: Train and save the Ensemble Model\n",
    "\n",
    "---\n",
    "📢 Find more LLM notebooks on my [GitHub repository](https://github.com/lisekarimi/lexo)"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "mzYB4XYQeWRQ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "mzYB4XYQeWRQ",
    "outputId": "f474ce9b-09fb-4a47-93d7-273fe2d2ba10"
   },
   "outputs": [],
   "source": [
    "# Install required packages in Google Colab\n",
    "%pip install -q tqdm huggingface_hub numpy sentence-transformers datasets chromadb xgboost peft torch bitsandbytes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b3caecd1-8712-4acd-80b5-e8059c16f43f",
   "metadata": {
    "id": "b3caecd1-8712-4acd-80b5-e8059c16f43f"
   },
   "outputs": [],
   "source": [
    "# imports\n",
    "\n",
    "import os\n",
    "import re\n",
    "import zipfile\n",
    "import chromadb\n",
    "import joblib\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import requests\n",
    "import torch\n",
    "from datasets import load_dataset\n",
    "from google.colab import userdata\n",
    "from huggingface_hub import HfApi, hf_hub_download, login\n",
    "from openai import OpenAI\n",
    "from peft import PeftModel\n",
    "from sentence_transformers import SentenceTransformer\n",
    "from sklearn.linear_model import LinearRegression\n",
    "from sklearn.metrics import r2_score, mean_squared_error\n",
    "from sklearn.metrics import r2_score\n",
    "from tqdm import tqdm\n",
    "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "05d9523f-b6c9-4132-bd2b-6712772b3cd2",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "05d9523f-b6c9-4132-bd2b-6712772b3cd2",
    "outputId": "7077320e-43e2-4b03-ca7d-e7ea9a3407f8"
   },
   "outputs": [],
   "source": [
    "# Mount Google Drive to access saved ChromaDB and XGBoost model files\n",
    "\n",
    "from google.colab import drive\n",
    "drive.mount(\"/content/drive\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "z9735RD_TUHw",
   "metadata": {
    "id": "z9735RD_TUHw"
   },
   "outputs": [],
   "source": [
    "# Load from Colab's secure storage\n",
    "\n",
    "openai_api_key = userdata.get(\"OPENAI_API_KEY\")\n",
    "openai = OpenAI(api_key=openai_api_key)\n",
    "\n",
    "hf_token = userdata.get(\"HF_TOKEN\")\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "DtswsfBQxxJF",
   "metadata": {
    "id": "DtswsfBQxxJF"
   },
   "outputs": [],
   "source": [
    "# Configuration\n",
    "\n",
    "HF_USER = \"lisekarimi\"\n",
    "ROOT = \"/content/drive/MyDrive/snapr\"\n",
    "os.makedirs(ROOT, exist_ok=True)\n",
    "\n",
    "api = HfApi(token=hf_token)\n",
    "REPO_NAME = \"smart-deal-finder-models\"\n",
    "REPO_ID = f\"{HF_USER}/{REPO_NAME}\""
   ]
  },
  {
   "cell_type": "markdown",
   "id": "qByarIFiTYa1",
   "metadata": {
    "id": "qByarIFiTYa1"
   },
   "source": [
    "### 📥 Load Test Dataset"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f9ca3e34",
   "metadata": {},
   "outputs": [],
   "source": [
    "# #If you face NotImplementedError: Loading a dataset cached in a LocalFileSystem is not supported run:\n",
    "# %pip install -U datasets"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0eKakxSFTVcA",
   "metadata": {
    "id": "0eKakxSFTVcA"
   },
   "outputs": [],
   "source": [
    "DATASET_NAME = f\"{HF_USER}/pricer-data\"\n",
    "dataset = load_dataset(DATASET_NAME)\n",
    "test = dataset[\"test\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "cWqvs8JRTggE",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 110
    },
    "id": "cWqvs8JRTggE",
    "outputId": "bf7f0113-de82-422a-aaec-54efbb2b9d16"
   },
   "outputs": [],
   "source": [
    "# Format description function (no price in text)\n",
    "def description(item):\n",
    "    text = item[\"text\"].replace(\n",
    "        \"How much does this cost to the nearest dollar?\\n\\n\", \"\"\n",
    "    )\n",
    "    text = text.split(\"\\n\\nPrice is $\")[0]\n",
    "    return f\"passage: {text}\"\n",
    "\n",
    "\n",
    "description(test[0])"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "alpkYSc2UX0n",
   "metadata": {
    "id": "alpkYSc2UX0n"
   },
   "source": [
    "### 📥 Load Models and ChromaDB"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "pjPBEgXqmHOA",
   "metadata": {
    "id": "pjPBEgXqmHOA"
   },
   "outputs": [],
   "source": [
    "# ChromaDB\n",
    "\n",
    "CHROMA_PATH = f\"{ROOT}/chroma\"\n",
    "COLLECTION_NAME = \"price_items\"\n",
    "CHROMA_ZIP_URL = \"https://aiprojects-lise-karimi.s3.eu-west-3.amazonaws.com/smart-deal-finder/chroma.zip\"\n",
    "\n",
    "# Download and unzip if CHROMA_PATH doesn't exist\n",
    "if not os.path.exists(CHROMA_PATH):\n",
    "    os.makedirs(CHROMA_PATH, exist_ok=True)\n",
    "    r = requests.get(CHROMA_ZIP_URL)\n",
    "    with open(\"/tmp/chroma.zip\", \"wb\") as f:\n",
    "        f.write(r.content)\n",
    "    with zipfile.ZipFile(\"/tmp/chroma.zip\", \"r\") as zip_ref:\n",
    "        zip_ref.extractall(CHROMA_PATH)\n",
    "\n",
    "client = chromadb.PersistentClient(path=CHROMA_PATH)\n",
    "collection = client.get_or_create_collection(name=COLLECTION_NAME)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8fi1BS71XCv1",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 337,
     "referenced_widgets": [
      "c60f1153084a493ea31fac10bf986aef",
      "6de41ac188dd48aea5d30a90bc52c38c",
      "d2b4cdcaef6a4c41972f8c96af2814ab",
      "2180dfb4e6e74df5bf9985c481b6e420",
      "dcff84c8c3bf4f4bae334e0484207d10",
      "ca1ed709ecaa4a0e8ac96ffb930e6613",
      "e5297e7d36334c57aece043f62c79841",
      "5fdca4e0987a4788983c418941711d7e",
      "200e8e9b0df84affa177567243bf18d1",
      "ef6e6dcff8b444bba62c1f76e1127d7c",
      "7c94357c0d4e444489e8d47d2151437b",
      "4ffdcd2ec96046ffb5121def27c95c9d",
      "dd47ad1efe46496cb096a7714cf27c19",
      "194ad4f8707b4d288e88cdbdfa33605d",
      "c8a05ae3f5854f24998cc615a8849c88",
      "b5ec411e72a946f8b4a470de5827c949",
      "1a0be25b030d43858cee804da65d67a1",
      "88d1f6a56f9b4a50854aee82c0945cc9",
      "73e5967ae96942e080f3b05638583bc8",
      "8434bfa06abf42c98e8ffb0e7b83c9f9",
      "43d872e632da4d9883ea3d71dc91bdf9",
      "86729f54df1b4967b2730b48f84a98aa",
      "83fcebcf2b2c4213835334a998ba91e9",
      "4ef06b10bfcb418d85534a8b73688eff",
      "88c466cc89234d8f9f21147882fc5faf",
      "f87e958c639544c0b925646fc28c4604",
      "a52988b97dff4759a456398ecea1eaaf",
      "c1116a13be86401bbbf6e51de0df7d12",
      "d6ed27ce322748d29ed864808f619ee3",
      "4a6fbedd3333496081695800cae8bdda",
      "a7badc083fb34e69bd6f27bc9a805e7b",
      "a78bcdbac2f74c72938d87c431f23e78",
      "1d627cf1043642a3815a2902f65b4ded",
      "3b8cc480ded24f66b03779fd25844670",
      "0ce0073368c64339b3c1f960861e4b56",
      "ffc973a4347943ebaa4ead16e04c05f6",
      "aafba411ee984946a3ec0760580b60b6",
      "0dd2501d917f48739b2817d598541660",
      "213ca3afc47945a68e28a6ae005c3b7d",
      "6355e004d7c34b969b2d2c6ccbc12620",
      "9359b873cb4a4187b67a1732d78c7534",
      "32f86a2b9e0547a6bc0a523ca3cfa088",
      "0f446cb8ee3147438ef1e98e665a2831",
      "64bb7ebae66d42f2a4d6a3039bf67d4b",
      "e2b51ee511234ff2bc2cf33227fe2088",
      "a76d3def06db41fab4ab2f077839d5bc",
      "fa9598b858c14024aaf15d1417e9683d",
      "df0bb9a9635643ebb679e115f45dde8e",
      "527c4d1987334e3e9b2aa0de7d0527a7",
      "0c6a889a9066484abbcb87b730d7e325",
      "80a1f4c902154f2184c38ef844a1cca7",
      "463c3cc65cd343108fe6049e4cde7142",
      "fb015ce2fd9b48d79db67f80181964b7",
      "07f46375dc594cd19ac5ab983083b2de",
      "451ca5f213544cc8b24de6b7d55602fc",
      "5bb2e645ea7741839e0f88ae484d94d2",
      "19d1353070f643e08364500e9b1c30e6",
      "fff94d7934cc4793876903d1c18efbfa",
      "f0413b6310ef4510bf493e6814fa162d",
      "a97c1662e64c44ee9f6e5be5617c07ec",
      "da02a5ab5fe44cc297ab3048509a99e1",
      "ee8c23aa2ca84b32a02a2500917559d8",
      "6635ef559f72485e9453f87b3921f954",
      "d9f2925a563d4d9fa332c15205f44d9b",
      "73ec7891d53149e7a072a0e310716178",
      "f7309076b36e4224acc42ade5d09bf37",
      "cd53294ff44e4955afbfbd4660563b58",
      "2ae42eb6385f43fab59f2bb56bb8a28e",
      "fc0f1abaaf054d0d93a27c7ee0f6630d",
      "c7a61078596c475784307480d26e3661",
      "5bd26e4ff28f4639b52aab848ada03d1",
      "ce710ef5ffa14cfe9842c63caebc81e5",
      "63061726d47940c395a00d5d01556f4f",
      "2b41598231d14f3ca6354c9543ec4351",
      "0df33079f63c45d39de21439289aa4a8",
      "8cc5d2eac9a64d68b72608bb5ae44c89",
      "dd33a409204c4610a08e44c3e82e00da",
      "b96b7ac71a6c48a9a6c888f2f34efea5",
      "794d71dd5b734a3cb5607fc31aaddd18",
      "0f6e7a2d9b8846178a7492e137d83bba",
      "657bb839f0ea40eb9873385cecd06fd0",
      "1ac174e8904943bb9a5e5483e58eef63",
      "a1c9714fc4ea48af83669481e89c58c7",
      "763a2d64c8e94ca1b0289264d9f868bc",
      "ef78cf15ab914b3fa95ff95a86ec7a99",
      "24350ebdc38a41e689f3e3b09dfc3e35",
      "b8131b3c4c4c4b20809af9b0e91dd006",
      "420c50ed8abe49ec9f4f2777e6cd2749",
      "fc3f2d2c33ee40f8850710c2f4ce331f",
      "90730ec699e84ddda2af799f8220e7a5",
      "01f1c4b3b434474dbf2212a05869354d",
      "912d6c1687324bb9b334bbf98a2b5b30",
      "dc4106a0020b4b9fa21cd12a44967f2e",
      "0084e537ffa74ab4a6d5f307b0916d2c",
      "0edeb9ca771c4ca2a9a678e0e8a91614",
      "b610d515ddb4405695e6972e45463194",
      "4af72cb05f284d42bca73fcb88904255",
      "86f34fb6325e4e878eee0be27946c88b",
      "bfc26f456f1d440bb80eedac1cb14967",
      "1f005c3cf7594275a37bb937a3c33db3",
      "4347e7d3db4d4cc3836a4e69db032f27",
      "809a4d0270dd4c05817ea224bb78ff5a",
      "b403a344e84342e4b076e64e829d7354",
      "6b3cbb0ac0b14e3fa193cd5cb3f8f521",
      "2d06aaf8d15b456b8fadeb54dd2ea73d",
      "74842e94dda648d18cd055220a3d2b39",
      "5a3818bde07841fbb6077bf20b7dec4b",
      "2a5224c8b3004d249a07297a2111493f",
      "ff02ac7f08974426b3f70b71e59ed5bb",
      "0dd550d3e39f42809fa16770231af7e5"
     ]
    },
    "id": "8fi1BS71XCv1",
    "outputId": "9256b509-1371-4bb3-bd84-98bb75725ac3"
   },
   "outputs": [],
   "source": [
    "# Embedding Model\n",
    "\n",
    "embedding_model = SentenceTransformer(\"intfloat/e5-small-v2\", device=\"cuda\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "zmwIbufXUzMo",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 483,
     "referenced_widgets": [
      "36b26207185e4e10a8c60f0f5918aa7c",
      "96a0c5044c824968a701e20319c8d037",
      "a567759a6e554c2eb9334559e880a56e",
      "dcc68ce908ac4cfc9ffdcd0333cce14b",
      "06d7ac4abde74e1496ef80b9e22cd148",
      "94ba3ecd90f84291a46462a51ba001b8",
      "00b94b9d1e4e4e428c7467b4f99e45af",
      "02219374302849bd93d7aba7f65ee42b",
      "82223db2023045e88dc3d9652bc6b183",
      "6e08a796d2c440c58d5e6bb20f39be16",
      "6ac04725da4f485f8d7efd738de0940a",
      "230a6add3adb4dbd9f0329412e3f1455",
      "db5f777e8ec143fe89a2e6296031c9c2",
      "4f87c3aba240471d96349a037777732e",
      "74b04b71ce0746a89d146c4044c85890",
      "582e3cd91e3048639ab48492b5ca4b15",
      "1d8e31a4418c49e0bfa28b399673f718",
      "ec17f9e6bbbc4584baf11c0e7c504f84",
      "19779fea3c8c4a5cbea953b059775110",
      "3c82450a6db649f59f36668d5521d203",
      "3447d41776ef406bbfbe3e6c277b82d2",
      "4195859f8457499ba8e61a9f1662931c",
      "6f80cc0bed9b42e89f9691675ff5a484",
      "dc39eee63325479a915101649ab04273",
      "99e1e2c842f845d1b2ded34736b60ad7",
      "20ff78d4df09401d8aee462b57c57a48",
      "1aff04abbd0b4f15bd58154b00591264",
      "9cf024b488154926861d695137268da1",
      "6470d6b03e3548d783ed252d128ec361",
      "230a79f3d05d4c9eac73ae6962cb5d2d",
      "adfe1cc7a6f94b0ab43170c75688374d",
      "b6d24621c29f4352862f37ae69f2d6ad",
      "303a6f79669e40d98ed2998f4f5e47cc",
      "b737799377c948a99edf34c014a105c4",
      "daf9616a687f406da1d9ee2bd147c850",
      "cfd34b617326486498a916531bae9a87",
      "201ea14e31a244a3aa2aeed2c12fb255",
      "8454d5f263eb45daa0e6f7db6aa3f92d",
      "f9787ee50d634421ae5f0325126dcb73",
      "c2200daaf2994ed9ae16587d8d52236c",
      "93a0ec9723074e9199d1f9db988c30dc",
      "b1cb687d58ba40f49961d5485a466ef7",
      "ccc580852d66401993d675c254832379",
      "7321b89aea1c4746901ef40548bcf056",
      "8fae24599e174689adbaa52a16270785",
      "01605d6fb53d4e56aca9a746b2c75566",
      "45cb352704fe41f1be0f01c61511323d",
      "881550789f4845dda8561f6b26aff204",
      "654bf64ec165449993b195209d75f4ba",
      "9eaa6e09335047d5987c0a6528d5e77e",
      "39331a837a644795bada1e2b034fe14d",
      "a339562553394755811bb7077a81843d",
      "4f58f5fa385f44e0ae09e2019294d597",
      "431cc587951845d6af39f3e4ab0f2f76",
      "ac879c2c923b423dabd6d0d60b12266d",
      "4a3eb0fc1d2d4606a8acb382085a57af",
      "f82bbcbd14ad48a78ebdfdfb43916bd1",
      "a634449526034dbdb945c4905f4edeba",
      "da4133a915ad4d449876a43468203842",
      "a56f3a61d1ba4011ab6fce4067fa8418",
      "46727d5afdf7487fb073e7e2d25cc75d",
      "1480c5a1a0ad4151a12d47bc22685f04",
      "5720fc31c90344908d9eeb49fe83df1b",
      "47bd422d424a47e48edd304773162082",
      "6a68ecb89be34255a0e0fc6db41c1f4d",
      "12b1b3c7f0914030ad756b676cc97962",
      "780f5b6ad91142f991a936b55219f61d",
      "45723f91352b49688469a95e7f47aa9d",
      "90bbf502500340a1993a957c27ad3d33",
      "dcbb25b2a082476d905bcf124a849322",
      "38e6bd6d64b54f9a8cdb4f40eaf41cde",
      "7544a101cad94a15a2f4eb5639d22525",
      "501184a0fc424a02a80aecd3f62fb9fd",
      "1cbf5d28bced46ac8712a4609b5a5867",
      "9ba6ecc0a422472681d8e61bdb32f87b",
      "8383700148dd44538ed81ec5a261b7b5",
      "740e930ff17c4f668818a8c762a5470e",
      "0d995d8da0464c9ca7b1b444c22de025",
      "bc8a5e6d27ba402996434f00918c8b0e",
      "75f225b1a6f845148361b029878b63ea",
      "c16b051d3cb44607b339770f5f8b6f2e",
      "7857ed1e0b0f45bdb48269fbad68653d",
      "e65ffc77bd6740c7924aa5b93297cb89",
      "3ba97a41b4654dc0bb9bcaaa685b4518",
      "6504931865a74cb5a80f2ac60da47430",
      "ff8d5791b13c440d81312a6b96c9592f",
      "c31bd8cb693b4e248a29f2ded032fb70",
      "c25b1a42547a422ba7597c99ca4ce249",
      "5c6338fcad9344e092f5077bf73c4910",
      "dca176a7a6ea4fe9b025f851976f436a",
      "611fc076771a4fcca5c46367b711d61a",
      "0cfde45e26cf4c05b67755c2274f2df8",
      "5d572d2f46ea484587e085c29318b616",
      "8ad59c1261844a06b7abecebd7b60377",
      "82d610cd077c47bd9efd609f2399c861",
      "07fd58c9d07144a7a0aacab6b8252125",
      "aa0932b4e66b4f33ba9f5237ea1470da",
      "0d867615a23a42988bb91b4f0d0cc942",
      "9c28ff7b0f5c421390ac1ccb899f093f",
      "0645e7ed6593410eaf9c9c0b25158667",
      "ca68b1dc60a343f9bd7298a63cadd556",
      "9597ec6b495c4298b87967ed3e4044db",
      "d7dad0ae58814124af1e92a078122736",
      "173575b8b5254537937206759d6b6262",
      "d1efdc10d36441d88cf7705e846bfbef",
      "3550d450f95f40eeae0c0d559ae9f4de",
      "773df79ed7b44f698cce98ca9ed802cc",
      "24a2b4b88e1d46488eabb9101536beac",
      "0d61c01e6cdf445b9474f9d759676edd",
      "9b3505aacd164a19b45aba89eee46378",
      "bd6fb8b066be46aaa7d457bf89257e54",
      "f1eca4e5d600407885264d340b4f47a7",
      "ef1e6a69995845e09781f76a38fced30",
      "c5f50067867a40b99cb9f312e8adc49f",
      "4e9cf63dbef041aba2c7f0b9c74466c8",
      "07baa025b3a14bf89d6f6b438b695bbd",
      "fb8609ef5b8d4653b25e52f853b7be1f",
      "2f90ac58752347319d1203b5e8765c0e",
      "d9815cabf324472a8eab585afabfa47e",
      "120687e04065424595571941d816a134",
      "a6d09159931f4d3d91a0647d9fa9d8ba",
      "833c755f7bf9479abeef0041a82a92ba",
      "56d848676e644c739e28730af99d69c6",
      "39654eca8add4c09815cb3e6a45616ec",
      "02aa70e064744a29af0a68aaca33c741",
      "7fd83f95cfef4b1dbd881be7083d7455",
      "cdf56625053b417fad2e64a0bed6725b",
      "4199341c09bd46fab8a3b649d0c8af7a",
      "334e67f38b8243aa9072f52a32e46080",
      "492cdb40ffbf444d8e256875663fc598",
      "655c6e0f21ff4e9db7f35522355d847c",
      "926c1be6e26f4d9eba332881f975ed38",
      "47641f7363be4252b9f5e53846bee057",
      "887f8a2b268541eab71804a44ba1479b",
      "5b6e78d1727e473ab3b66d6ff042aeca",
      "acc60a1210104049983341db3010be0a",
      "c028c37980e14b3ea07b1da6f558651e",
      "9d1085906e3548078e5e393a86337c3e",
      "259c86a51f4e4cab9648cc603fc25c7e",
      "8ce05076e77643a88b062687e2b24493",
      "8b3b7f947f4d4401bbca47d5720f7450",
      "9bf9dccee248425da698dbb4526fcad9",
      "b991477124184bf3b4397762649a6596",
      "fb139bcae29f49778bf172eb503c0668",
      "93ba52daa9aa4dee8da91bba6c7d0269",
      "fefedf36efc94ed287bdeceae698d5b5",
      "8288b87b06b34ba4b2c7a343d6cea827",
      "5a85d212a6b5468fbc10e6aeb0ad8bee",
      "22274b08e14c4c77a6223131779f6f48",
      "edc0e436da954a33bbea8e80629eb43c",
      "1d9b1c594680467f9c8a6682d8aeb2e7",
      "d82f940e0a8a478e8b1ee8f169f798fc",
      "ab1616f507594b27b898ede4504b4e39",
      "1fed11f1c7484251a2a7400627ad5f6a"
     ]
    },
    "id": "zmwIbufXUzMo",
    "outputId": "2acb6897-4c41-4447-e029-ffcc1b3b4da1"
   },
   "outputs": [],
   "source": [
    "# Fine Tuned Llama Model\n",
    "\n",
    "BASE_MODEL = \"meta-llama/Meta-Llama-3.1-8B\"\n",
    "FINETUNED_MODEL = \"ed-donner/pricer-2024-09-13_13.04.39\"\n",
    "REVISION = \"e8d637df551603dc86cd7a1598a8f44af4d7ae36\"\n",
    "\n",
    "# Quantization config (4-bit)\n",
    "quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    bnb_4bit_quant_type=\"nf4\",\n",
    ")\n",
    "\n",
    "# Load tokenizer\n",
    "tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL, trust_remote_code=True)\n",
    "tokenizer.pad_token = tokenizer.eos_token\n",
    "tokenizer.padding_side = \"right\"\n",
    "\n",
    "# Load base model\n",
    "base_model = AutoModelForCausalLM.from_pretrained(\n",
    "    BASE_MODEL, quantization_config=quant_config, device_map=\"auto\"\n",
    ")\n",
    "\n",
    "# Load fine-tuned model\n",
    "fine_tuned_model = PeftModel.from_pretrained(\n",
    "    base_model, FINETUNED_MODEL, revision=REVISION\n",
    ")\n",
    "\n",
    "# Align generation config\n",
    "fine_tuned_model.generation_config.pad_token_id = tokenizer.pad_token_id\n",
    "\n",
    "print(f\"Memory footprint: {fine_tuned_model.get_memory_footprint() / 1e6:.1f} MB\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "0IHiJNU7a4XC",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 49,
     "referenced_widgets": [
      "0264a3987fbf4040860ffa3fc47940d8",
      "06d1db35940b469797c39c653741ea36",
      "84c4d2fdaf734a559ee3eee09f1be295",
      "fddde0bfed544b18ba39bfaa40eb9e1b",
      "d40cc525cc28416cad4a45b3631798c9",
      "e1372af176154902b1f555f30c28c007",
      "5a1352c5ceb84320b14353b7aa21650d",
      "522d0ed9e705457e9c72d276e2a26dbd",
      "4de73aa76f044811990c379737a8e5c0",
      "9305e96697ab4854ac89a6636991101d",
      "b00e41d1051340fd904ba719111a907d"
     ]
    },
    "id": "0IHiJNU7a4XC",
    "outputId": "c68bc44e-6b15-46c3-c8d9-3f256f368317"
   },
   "outputs": [],
   "source": [
    "# XGBoost Trained Model\n",
    "\n",
    "MODEL_FILENAME = \"xgboost_model.pkl\"\n",
    "model_path = hf_hub_download(repo_id=REPO_ID, filename=MODEL_FILENAME, token=hf_token)\n",
    "xgb_model = joblib.load(model_path)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "76BhcPjWa6C5",
   "metadata": {
    "id": "76BhcPjWa6C5"
   },
   "source": [
    "### 📊 Model prediction collection"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "LgGmUKJxayZ6",
   "metadata": {
    "id": "LgGmUKJxayZ6"
   },
   "outputs": [],
   "source": [
    "def extract_tagged_price(output: str):\n",
    "    \"\"\"Extracts a float price from a string based on 'Price is $' keyword.\"\"\"\n",
    "    try:\n",
    "        contents = output.split(\"Price is $\")[1].replace(\",\", \"\")\n",
    "        match = re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", contents)\n",
    "        return float(match.group()) if match else 0.0\n",
    "    except Exception:\n",
    "        return 0.0"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "ggKf1nSQbAnv",
   "metadata": {
    "id": "ggKf1nSQbAnv"
   },
   "outputs": [],
   "source": [
    "def ft_llama_price(description: str):\n",
    "    prompt = (\n",
    "        f\"How much does this cost to the nearest dollar?\\n\\n{description}\\n\\nPrice is $\"\n",
    "    )\n",
    "    inputs = tokenizer(prompt, return_tensors=\"pt\").to(\"cuda\")\n",
    "\n",
    "    outputs = fine_tuned_model.generate(\n",
    "        **inputs, max_new_tokens=5, num_return_sequences=1\n",
    "    )\n",
    "\n",
    "    result = tokenizer.decode(outputs[0])\n",
    "    price = extract_tagged_price(result)\n",
    "    return price"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "_cWyYUd4Ub-K",
   "metadata": {
    "id": "_cWyYUd4Ub-K"
   },
   "outputs": [],
   "source": [
    "def xgboost_price(description: str):\n",
    "    vector = embedding_model.encode([description], normalize_embeddings=True)[0]\n",
    "    pred = xgb_model.predict([vector])[0]\n",
    "    return round(float(max(0, pred)), 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "3Skod8juXgnN",
   "metadata": {
    "id": "3Skod8juXgnN"
   },
   "outputs": [],
   "source": [
    "def gpt4o_price(item):\n",
    "    def get_embedding(text):\n",
    "        return embedding_model.encode([text], normalize_embeddings=True)\n",
    "\n",
    "    def find_similars(text):\n",
    "        results = collection.query(\n",
    "            query_embeddings=get_embedding(text).astype(float).tolist(), n_results=5\n",
    "        )\n",
    "        docs = results[\"documents\"][0]\n",
    "        prices = [m[\"price\"] for m in results[\"metadatas\"][0]]\n",
    "        return docs, prices\n",
    "\n",
    "    def format_context(similars, prices):\n",
    "        context = (\n",
    "            \"To provide some context, here are similar products and their prices:\\n\\n\"\n",
    "        )\n",
    "        for sim, price in zip(similars, prices):\n",
    "            context += f\"Product:\\n{sim}\\nPrice is ${price:.2f}\\n\\n\"\n",
    "        return context\n",
    "\n",
    "    def build_messages(description, similars, prices):\n",
    "        system_message = (\n",
    "            \"You are a pricing expert. \"\n",
    "            \"Given a product description and a few similar products with their prices, \"\n",
    "            \"estimate the most likely price. \"\n",
    "            \"Respond ONLY with a number, no words.\"\n",
    "        )\n",
    "        context = format_context(similars, prices)\n",
    "        user_prompt = (\n",
    "            \"Estimate the price for the following product:\\n\\n\"\n",
    "            + description\n",
    "            + \"\\n\\n\"\n",
    "            + context\n",
    "        )\n",
    "        return [\n",
    "            {\"role\": \"system\", \"content\": system_message},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "            {\"role\": \"assistant\", \"content\": \"Price is $\"},\n",
    "        ]\n",
    "\n",
    "    docs, prices = find_similars(description(item))\n",
    "    messages = build_messages(description(item), docs, prices)\n",
    "    response = openai.chat.completions.create(\n",
    "        model=\"gpt-4o-mini\", messages=messages, seed=42, max_tokens=5\n",
    "    )\n",
    "    reply = response.choices[0].message.content\n",
    "    return float(\n",
    "        re.search(r\"[-+]?\\d*\\.\\d+|\\d+\", reply.replace(\"$\", \"\").replace(\",\", \"\")).group()\n",
    "        or 0\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "98bf0aed",
   "metadata": {},
   "source": [
    "### ✂️ Split dataset and process"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "8XQK5yrk8On4",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "8XQK5yrk8On4",
    "outputId": "ec379798-8b73-4e66-a517-a818845c8353"
   },
   "outputs": [],
   "source": [
    "print(\"Splitting entire dataset...\")\n",
    "np.random.seed(42)\n",
    "all_indices = list(range(len(test)))\n",
    "np.random.shuffle(all_indices)\n",
    "\n",
    "train_split_size = int(0.8 * len(all_indices))\n",
    "train_indices = all_indices[:train_split_size]  # 80% of total\n",
    "test_indices = all_indices[train_split_size:]  # 20% of total\n",
    "\n",
    "train_indices = train_indices[:250]  # First 250 from training split\n",
    "test_indices = test_indices[:50]  # First 50 from testing split"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "XN7P5fkkXfgP",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "XN7P5fkkXfgP",
    "outputId": "69f9d265-a402-48ab-a91e-8c6032ea4118"
   },
   "outputs": [],
   "source": [
    "# Process subset of TRAINING data\n",
    "ft_llama_preds_train = []\n",
    "gpt4omini_preds_train = []\n",
    "xgboost_preds_train = []\n",
    "true_prices_train = []\n",
    "\n",
    "for i in tqdm(train_indices):\n",
    "    item = test[i]\n",
    "    text = description(item)\n",
    "    true_prices_train.append(item[\"price\"])\n",
    "    ft_llama_preds_train.append(ft_llama_price(text))\n",
    "    gpt4omini_preds_train.append(gpt4o_price(item))\n",
    "    xgboost_preds_train.append(xgboost_price(text))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "1_6_atEgHnFR",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "1_6_atEgHnFR",
    "outputId": "956e4dcb-2300-44ab-a66b-9b1254216762"
   },
   "outputs": [],
   "source": [
    "print(\"True Prices:\", true_prices_train)\n",
    "print(\"FT-LLaMA Predictions:\", ft_llama_preds_train)\n",
    "print(\"GPT-4o-mini Predictions:\", gpt4omini_preds_train)\n",
    "print(\"XGBoost Predictions:\", xgboost_preds_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "ygJsuvtLtOdR",
   "metadata": {
    "id": "ygJsuvtLtOdR"
   },
   "source": [
    "Example :\n",
    "- True Prices: [245.0, 24.99, 302.4, 737.0, ...]\n",
    "- FT-LLaMA Predictions: [99.0, 53.0, 550.0, 852.0, ...]\n",
    "- GPT-4o-mini Predictions: [179.99, 97.0, 348.0, 769.0, ...]\n",
    "- XGBoost Predictions: [220.19, 59.85, 254.29, 335.76, 165.04, ...]"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "tYWMhTrXcA7x",
   "metadata": {
    "id": "tYWMhTrXcA7x"
   },
   "outputs": [],
   "source": [
    "# Create features for TRAINING data\n",
    "maxes_train = [\n",
    "    max(a, b, c)\n",
    "    for a, b, c in zip(ft_llama_preds_train, gpt4omini_preds_train, xgboost_preds_train)\n",
    "]\n",
    "means_train = [\n",
    "    np.mean([a, b, c])\n",
    "    for a, b, c in zip(ft_llama_preds_train, gpt4omini_preds_train, xgboost_preds_train)\n",
    "]\n",
    "\n",
    "# Create TRAINING dataframe\n",
    "X_train = pd.DataFrame(\n",
    "    {\n",
    "        \"FT_LLaMA\": ft_llama_preds_train,\n",
    "        \"GPT4oMini\": gpt4omini_preds_train,\n",
    "        \"XGBoost\": xgboost_preds_train,\n",
    "        \"Max\": maxes_train,\n",
    "        \"Mean\": means_train,\n",
    "    }\n",
    ")\n",
    "\n",
    "y_train = pd.Series(true_prices_train)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e1682cf0",
   "metadata": {},
   "source": [
    "### 🏋️Train the Ensemble Model"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "-WsFABEicOyo",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "-WsFABEicOyo",
    "outputId": "42ae6421-fb4e-4ae6-ab54-b075e311b94d"
   },
   "outputs": [],
   "source": [
    "np.random.seed(42)\n",
    "lr = LinearRegression()\n",
    "lr.fit(X_train, y_train)\n",
    "\n",
    "# Print feature coefficients\n",
    "feature_columns = X_train.columns.tolist()\n",
    "for feature, coef in zip(feature_columns, lr.coef_):\n",
    "    print(f\"{feature}: {coef:.2f}\")\n",
    "print(f\"Intercept={lr.intercept_:.2f}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "GnYPOslHFgGx",
   "metadata": {
    "id": "GnYPOslHFgGx"
   },
   "source": [
    "- FT_LLaMA: 0.52\n",
    "- GPT4oMini: 0.17\n",
    "- XGBoost: -0.31\n",
    "- Max: 0.45\n",
    "- Mean: 0.13\n",
    "- Intercept=-6.06\n",
    "\n",
    "---\n",
    "FT_LLaMA is the most influential model in the ensemble.\n",
    "\n",
    "Max prediction also has strong positive impact.\n",
    "\n",
    "GPT4oMini and Mean contribute less, but still add value.\n",
    "\n",
    "XGBoost has a negative coefficient, acting as a counterbalance.\n",
    "\n",
    "\n",
    "Overall: FT_LLaMA leads, max adds value, XGBoost corrects for overestimation—resulting in a balanced ensemble."
   ],
   "outputs": []
  },
  {
   "cell_type": "markdown",
   "id": "wyx39HEL9niI",
   "metadata": {
    "id": "wyx39HEL9niI"
   },
   "source": [
    "### 🔮 Prediction"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "W3F0nNBXlrUJ",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "W3F0nNBXlrUJ",
    "outputId": "1dbd9702-50cf-4d80-b8ab-9b2000dd3b10"
   },
   "outputs": [],
   "source": [
    "# Process subset of TEST data\n",
    "ft_llama_preds_test = []\n",
    "gpt4omini_preds_test = []\n",
    "xgboost_preds_test = []\n",
    "true_prices_test = []\n",
    "\n",
    "print(\"Processing TEST data (50 items)...\")\n",
    "for i in tqdm(test_indices):\n",
    "    item = test[i]\n",
    "    text = description(item)\n",
    "    true_prices_test.append(item[\"price\"])\n",
    "    ft_llama_preds_test.append(ft_llama_price(text))\n",
    "    gpt4omini_preds_test.append(gpt4o_price(item))\n",
    "    xgboost_preds_test.append(xgboost_price(text))\n",
    "\n",
    "# Create features for TEST data\n",
    "maxes_test = [\n",
    "    max(a, b, c)\n",
    "    for a, b, c in zip(ft_llama_preds_test, gpt4omini_preds_test, xgboost_preds_test)\n",
    "]\n",
    "means_test = [\n",
    "    np.mean([a, b, c])\n",
    "    for a, b, c in zip(ft_llama_preds_test, gpt4omini_preds_test, xgboost_preds_test)\n",
    "]\n",
    "\n",
    "# Create TEST dataframe\n",
    "X_test = pd.DataFrame(\n",
    "    {\n",
    "        \"FT_LLaMA\": ft_llama_preds_test,\n",
    "        \"GPT4oMini\": gpt4omini_preds_test,\n",
    "        \"XGBoost\": xgboost_preds_test,\n",
    "        \"Max\": maxes_test,\n",
    "        \"Mean\": means_test,\n",
    "    }\n",
    ")\n",
    "\n",
    "y_test = pd.Series(true_prices_test)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "mVn6AAGq96wm",
   "metadata": {
    "id": "mVn6AAGq96wm"
   },
   "source": [
    "### 🧪 Evaluation"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "y25l8rR791wG",
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/"
    },
    "id": "y25l8rR791wG",
    "outputId": "0a02a620-eb0d-46a6-8f54-1046c2394ab3"
   },
   "outputs": [],
   "source": [
    "# Evaluate on the test set\n",
    "print(\"Evaluating model...\")\n",
    "y_pred = lr.predict(X_test)\n",
    "r2 = r2_score(y_test, y_pred)\n",
    "print(f\"R² score: {r2:.4f}\")\n",
    "\n",
    "# Calculate RMSE\n",
    "rmse = np.sqrt(mean_squared_error(y_test, y_pred))\n",
    "print(f\"RMSE: {rmse:.2f}\")\n",
    "\n",
    "# Calculate MAPE\n",
    "mape = np.mean(np.abs((y_test - y_pred) / y_test)) * 100\n",
    "print(f\"MAPE: {mape:.2f}%\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "vHJLe6LNEBrB",
   "metadata": {
    "id": "vHJLe6LNEBrB"
   },
   "source": [
    "Evaluating model...\n",
    "- R² score: 0.7376\n",
    "- RMSE: 127.62\n",
    "- MAPE: 29.70%\n",
    "\n",
    "---\n",
    "\n",
    "- R² = 0.74: the ensemble explains about 74% of the variance in the test prices. Values above 0.7 are often treated as good for price prediction, though that is a rule of thumb and not a fixed threshold.\n",
    "- RMSE = 127.6: root mean squared error, in the same units as price. It penalizes large misses more heavily than small ones, so judge it against the price range of the test items; it is good if prices are in the thousands.\n",
    "- MAPE = 29.7%: predictions are off by roughly 30% of the true price on average, which leaves room for improvement.\n"
   ],
   "outputs": []
  },
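  {
   "cell_type": "markdown",
   "id": "ens-metric-definitions",
   "metadata": {
    "id": "ens-metric-definitions"
   },
   "source": [
    "For reference, the three metrics computed in the evaluation cell correspond to the definitions below. The notation is introduced here for illustration: $y_i$ is the true price, $\\hat{y}_i$ the ensemble prediction, $\\bar{y}$ the mean true price, and $n$ the number of test items.\n",
    "\n",
    "$$\n",
    "R^2 = 1 - \\frac{\\sum_i (y_i - \\hat{y}_i)^2}{\\sum_i (y_i - \\bar{y})^2}, \\qquad\n",
    "\\text{RMSE} = \\sqrt{\\frac{1}{n}\\sum_i (y_i - \\hat{y}_i)^2}, \\qquad\n",
    "\\text{MAPE} = \\frac{100}{n}\\sum_i \\left|\\frac{y_i - \\hat{y}_i}{y_i}\\right|\n",
    "$$\n",
    "\n",
    "A small made-up example shows why RMSE and MAPE can tell different stories. Take two items with true prices 100 and 1000, each predicted with an error of 50. Then $\\text{RMSE} = \\sqrt{(50^2 + 50^2)/2} = 50$, while $\\text{MAPE} = (50\\% + 5\\%)/2 = 27.5\\%$. The cheap item dominates MAPE even though both absolute errors are equal. MAPE is also undefined when a true price is zero."
   ]
  },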
  {
   "cell_type": "markdown",
   "id": "C6cEJ57WApkG",
   "metadata": {
    "id": "C6cEJ57WApkG"
   },
   "source": [
    "### 🚀 Push to HF"
   ],
   "outputs": []
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "_n7n_MnscS4r",
   "metadata": {
    "id": "_n7n_MnscS4r"
   },
   "outputs": [],
   "source": [
    "# Serialize Ensemble model locally for Hugging Face upload\n",
    "\n",
    "MODEL_DIR = os.path.join(ROOT, \"models\")\n",
    "MODEL_FILENAME = \"ensemble_model.pkl\"\n",
    "LOCAL_MODEL = os.path.join(MODEL_DIR, MODEL_FILENAME)\n",
    "\n",
    "os.makedirs(MODEL_DIR, exist_ok=True)\n",
    "joblib.dump(lr, LOCAL_MODEL)\n",
    "\n",
    "# Create the model repo if it doesn't exist\n",
    "api.create_repo(repo_id=REPO_ID, repo_type=\"model\", private=True, exist_ok=True)\n",
    "\n",
    "# Upload the saved model\n",
    "api.upload_file(\n",
    "    path_or_fileobj=LOCAL_MODEL,\n",
    "    path_in_repo=MODEL_FILENAME,\n",
    "    repo_id=REPO_ID,\n",
    "    repo_type=\"model\",\n",
    ")"
   ]
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}